# ======================================================================
# Copyright CERFACS (June 2019)
# Contributor: Adrien Suau (adrien.suau@cerfacs.fr)
#
# This software is governed by the CeCILL-B license under French law and
# abiding  by the  rules of  distribution of free software. You can use,
# modify  and/or  redistribute  the  software  under  the  terms  of the
# CeCILL-B license as circulated by CEA, CNRS and INRIA at the following
# URL "http://www.cecill.info".
#
# As a counterpart to the access to  the source code and rights to copy,
# modify and  redistribute granted  by the  license, users  are provided
# only with a limited warranty and  the software's author, the holder of
# the economic rights,  and the  successive licensors  have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using, modifying and/or  developing or reproducing  the
# software by the user in light of its specific status of free software,
# that  may mean  that it  is complicated  to manipulate,  and that also
# therefore  means that  it is reserved for  developers and  experienced
# professionals having in-depth  computer knowledge. Users are therefore
# encouraged  to load and  test  the software's  suitability as  regards
# their  requirements  in  conditions  enabling  the  security  of their
# systems  and/or  data to be  ensured and,  more generally,  to use and
# operate it in the same conditions as regards security.
#
# The fact that you  are presently reading this  means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.
# ======================================================================
"""This module implements the LCR optimisation described in
https://arxiv.org/pdf/1710.07345.pdf (Nam et al., "Automated optimization of
large quantum circuits with continuous parameters").

"""
from copy import deepcopy
import logging

from qat.lang.AQASM.program import Program
from qat.lang.AQASM.routines import QRoutine

from qaths.utils.circuit import merge as merge_circuits, repeat as repeat_circuit

logger = logging.getLogger("qaths.simulation.pf.optimisation")

# Graphopt is an optional dependency: the optimisation routine below is only
# defined when all the required modules can be imported.
try:
    from qat.core.simutil import optimize_circuit
    from qaths.pbo_rules.graphopt import adapt_to_graphopt
    from qat.graphopt import Graphopt

    def graphopt_optimisation(circ):
        """Run graphopt on a circuit after translating it to the graphopt gate-set.

        Two graphopt passes are chained: an expansion-only pass
        (``Graphopt(expandonly=True)``) followed by a default ``Graphopt()`` pass.

        :param circ: Circuit to optimise using graphopt.
        :return: the circuit produced by the second graphopt pass.
        """
        expanded = optimize_circuit(
            adapt_to_graphopt(circ), Graphopt(expandonly=True)
        )
        return optimize_circuit(expanded, Graphopt())


except ImportError:
    # graphopt_optimisation stays undefined when one of the imports fails.
    logger.warning(
        "Graphopt not found, you are probably on MyQLM. Disabling routines using it."
    )


def no_optimisation(circ):
    """Return the given circuit untouched.

    Used as the default, do-nothing optimisation procedure.

    :param circ: any object (typically a circuit).
    :return: the very same object that was passed in.
    """
    unchanged = circ
    return unchanged


def _common_prefix_length(first, second) -> int:
    """Return the number of leading elements shared by ``first`` and ``second``.

    Elements are compared with ``!=`` and the scan stops at the end of the
    shorter sequence, so this never indexes past either sequence.
    """
    length = 0
    for lhs, rhs in zip(first, second):
        if lhs != rhs:
            break
        length += 1
    return length


def _copy_with_ops(circuit, ops):
    """Return a deep copy of ``circuit`` whose ``ops`` attribute is set to ``ops``.

    The ``ops`` attribute of ``circuit`` is temporarily detached so that the
    (potentially large) list of operations is not deep-copied. It is restored
    afterwards, even if the copy raises.
    """
    saved_ops = circuit.ops
    circuit.ops = None
    try:
        copied = deepcopy(circuit)
    finally:
        circuit.ops = saved_ops
    copied.ops = ops
    return copied


def lcr_optimise(
    routine: QRoutine, repetitions: int, optimisation_procedure=no_optimisation
):
    """Optimise a routine that should be repeated a large number of times.

    This optimisation is described in https://arxiv.org/pdf/1710.07345.pdf.

    With O the optimised routine:
    - O² is optimised into L·R, where L is the longest common prefix of O and
      the optimised O²;
    - O³ is optimised into L·C·R;
    - O^t is then implemented as L·C^(t-2)·R.

    :param routine: The quantum routine that needs to be repeated and optimised.
    :param repetitions: The number of times the given routine should be repeated.
        Must be at least 1.
    :param optimisation_procedure: A function taking a circuit as input and returning
        an optimised circuit.
    :return: a circuit implementing `routine` repeated `repetitions` times. In the
        worst case, this is the `routine` optimised with `optimisation_procedure`
        and repeated `repetitions` times. The same fallback applies when the
        optimised O³ does not have the L·C·R shape. Otherwise it is a more
        optimised circuit.
    :raises ValueError: if `repetitions` is lower than 1.
    """
    if repetitions < 1:
        raise ValueError(
            "repetitions should be at least 1, got {}.".format(repetitions)
        )

    # Create the circuit from the given routine
    prog = Program()
    qbits = prog.qalloc(routine.arity)
    prog.apply(routine, qbits)
    non_optimised_circuit = prog.to_circ()
    print("Size before optimisation: {} gates".format(len(non_optimised_circuit.ops)))

    # Optimise a first time the created circuit, alone.
    O = optimisation_procedure(non_optimised_circuit)
    print("Size of O: {} gates".format(len(O.ops)))
    if repetitions == 1:
        return O

    # Then compute and optimise O², i.e. 2 successive applications of the optimised
    # circuit.
    OO = repeat_circuit(O, 2)
    print("Size of O²: {} gates".format(len(OO.ops)))
    LR = optimisation_procedure(OO)
    print("Size of LR: {} gates".format(len(LR.ops)))
    if repetitions == 2:
        return LR

    # L is the longest common prefix of O and LR, and R is the remainder of LR.
    # Bounding the scan by LR (not OO) matters: the optimised LR may be shorter
    # than O. When O is entirely a prefix of LR, L is the whole of O.
    prefix_length = _common_prefix_length(O.ops, LR.ops)
    L = _copy_with_ops(LR, LR.ops[:prefix_length])
    R = _copy_with_ops(LR, LR.ops[prefix_length:])
    print("Size of L: {} gates".format(len(L.ops)))
    print("Size of R: {} gates".format(len(R.ops)))

    # Repeat with O^3:
    OOO = repeat_circuit(O, 3)
    print("Size of O³: {} gates".format(len(OOO.ops)))

    LCR = optimisation_procedure(OOO)
    print("Size of LCR: {} gates".format(len(LCR.ops)))

    # Extract C from LCR. The end index is computed explicitly: a "-len(R.ops)"
    # slice bound would select nothing when R is empty.
    lcr_ops = LCR.ops
    c_start = len(L.ops)
    c_end = len(lcr_ops) - len(R.ops)
    # The decomposition is only valid if the optimised O³ really starts with L
    # and ends with R; otherwise merging L, C and R would build a wrong circuit.
    if (
        c_end < c_start
        or list(lcr_ops[:c_start]) != list(L.ops)
        or list(lcr_ops[c_end:]) != list(R.ops)
    ):
        print(
            "Optimised O³ is not of the form LCR, "
            "falling back to the optimised/repeat version."
        )
        return repeat_circuit(O, repetitions)
    C = _copy_with_ops(LCR, lcr_ops[c_start:c_end])
    print("Size of C: {} gates".format(len(C.ops)))

    # O^t = L·C^(t-2)·R: L and R already account for two applications of O.
    middle_repetitions = repetitions - 2
    final_size = len(L.ops) + middle_repetitions * len(C.ops) + len(R.ops)
    non_optimised_size = repetitions * len(non_optimised_circuit.ops)
    print(
        "Final circuit has {} gates VS {} gates for non-optimised version ".format(
            final_size, non_optimised_size
        )
        + "and {} for optimised/repeat version.".format(len(O.ops) * repetitions)
    )
    # Guard against an empty routine, which would make the ratio undefined.
    if non_optimised_size > 0:
        print(
            "Removed {}% of the total gate count.".format(
                int(100 * (1 - final_size / non_optimised_size))
            )
        )

    # Finally compute the final optimised circuit:
    return merge_circuits(L, repeat_circuit(C, middle_repetitions), R)
